"""
Author: xuzhougeng
Affiliation: SIPPE
Aim: A Snakemake workflow for mapping causal genes in forward genetic screens
Date: 2018-2-06
Version 0.1.1: Calling snp at all chromosomes simultaneously
Version 0.1.2: Add fastp as quality control process
Run: snakemake -s Snakefile

requirement:
- fastp
- bwa
- samtools
- bcftools
"""

# Globals-----------------------------------------------------------------------
from os.path import join

## reference file configuration
configfile: "config.yaml"
FASTA    = config['FASTA']
INDEX    = config['INDEX']
# Contig names are the first whitespace-delimited token of each FASTA header
# line. Any description after the name is dropped, so the names stay free of
# spaces when they are passed to --intervals and embedded in per-chromosome
# file names. Empty headers (a bare ">") are skipped.
# The FASTA is read inside a `with` block so the handle is closed immediately.
with open(FASTA, 'r') as fasta_handle:
    CHROMS = [
        header[1:].split()[0]
        for header in fasta_handle
        if header.startswith(">") and header[1:].strip()
    ]
# Ploidy handed to HaplotypeCaller: twice the configured POOLED value
# (presumably the number of pooled diploid individuals -- confirm in config).
PLOIDY   = int(config['POOLED']) * 2
GATK     = config['GATK']
# Options inserted right after the GATK executable in every GATK command.
OPTIONS  = '--java-options "-Xmx8G -Djava.io.tmpdir=./"'
snpeff   = config['snpeff']
SnpEffDB = config['SnpEffDB']

# Adapter sequence for fastp; when empty, no -a option is passed.
ADAPTER  = ""

# Working directories, one per analysis stage (all below "analysis/").
RAW_DIR, CLEAN_DIR, ALIGN_DIR, VCF_DIR, TMPDIR, LOGDIR = (
    join("analysis", subdir)
    for subdir in (
        "00-raw-data",
        "01-clean-data",
        "02-read-alignment",
        "03-variant-calling",
        "tmp",
        "log",
    )
)
# Report directories (below "report/").
QC_DIR, REPDIR = (join("report", subdir) for subdir in ("fastp", "snpeff"))

# Create the directories up front. RAW_DIR is not in the list: it is only
# read from (by data_clean), never created here.
shell("mkdir -p {CLEAN_DIR} {ALIGN_DIR} {VCF_DIR} {TMPDIR} {LOGDIR} {QC_DIR} {REPDIR}")

## Samples to process and the target files derived from them.
## Joint-call outputs are prefixed with both sample names joined by "_".
SAMPLES = [config['wildtype'], config['mutation']]

# one deduplicated, sorted BAM per sample
ALL_BAM = expand(join(ALIGN_DIR, "{sample}_markdup_sorted.bam"), sample=SAMPLES)
# one VCF per chromosome, later merged into RAW_VCF
ALL_VCF = expand(join(VCF_DIR, "_".join(SAMPLES)) + "_{chrom}.vcf", chrom=CHROMS)
# merged, filtered and annotated call sets
RAW_VCF = join(VCF_DIR, "_".join(SAMPLES)) + '_gatk_hc.vcf'
FLT_VCF = join(VCF_DIR, "_".join(SAMPLES)) + '_flt.vcf'
ANN_VCF = join(VCF_DIR, "_".join(SAMPLES)) + '_ann.vcf.gz'
# final tables
ALL_CSV = join("report", "all_gene.csv")
CAN_CSV = join("report", "causual_gene.csv")

# Rules------------------------------------------------------------------------

# Default target: request every final and intermediate product of the workflow.
rule all:
    input:
        ALL_BAM,
        ALL_VCF,
        RAW_VCF,
        FLT_VCF,
        ANN_VCF,
        ALL_CSV,
        CAN_CSV

# Quality-control the raw paired-end reads of one sample with fastp.
#   input : RAW_DIR/{sample}_1.fq.gz and RAW_DIR/{sample}_2.fq.gz
#   output: CLEAN_DIR/{sample}_1.fq.gz and CLEAN_DIR/{sample}_2.fq.gz
#   report: QC_DIR/{sample}.json and QC_DIR/{sample}.html (written by fastp via
#           -j/-h; not declared as rule outputs)
#   log   : fastp's stderr is redirected to the declared log file; Snakemake
#           does not fill a `log:` file on its own.
# The -a option is only passed when params.adapter is non-empty.
rule data_clean:
    input:
        r1 = join(RAW_DIR, "{sample}_1.fq.gz"),
        r2 = join(RAW_DIR, "{sample}_2.fq.gz")
    params:
        adapter = ADAPTER,
        length  = 50,
        quality = 20,
        prefix  = lambda wildcards: join(QC_DIR, wildcards.sample)
    output:
        r1 = join(CLEAN_DIR, "{sample}_1.fq.gz"),
        r2 = join(CLEAN_DIR, "{sample}_2.fq.gz")
    log: join(LOGDIR, "{sample}_fastp.log")
    threads: 4
    message: "----- processing {input.r1} and {input.r2} with fastp ------"
    shell: """
        fastp $(if [ -n "{params.adapter}" ]; then echo '-a {params.adapter}'; fi;) \
        -w {threads} -j {params.prefix}.json -h {params.prefix}.html -l {params.length} -q {params.quality} \
        -i {input.r1} -I {input.r2} -o {output.r1} -O {output.r2} 2> {log}
    """


# Align the cleaned read pairs of one sample and produce a sorted BAM.
# Pipeline: bwa mem -> samblaster (adds mate tags, writes split and discordant
# reads to their own SAM files) -> sambamba view (SAM to uncompressed BAM)
# -> sambamba sort (temporary files go to params.tmpdir).
# stderr of bwa and samblaster goes to log1 and log2 respectively.
# params.outdir and params.logdir are declared but not referenced by the command.
rule bwa_mem:
    message: "align {input.r1} and {input.r2} to {params.index} with {threads} threads"
    threads: 8
    input:
        r1 = join(CLEAN_DIR, "{sample}_1.fq.gz"),
        r2 = join(CLEAN_DIR, "{sample}_2.fq.gz")
    output:
        join(ALIGN_DIR, "{sample}_markdup_sorted.bam"),
        join(ALIGN_DIR, "{sample}_split.sam"),
        join(ALIGN_DIR, "{sample}_disc.sam")
    log:
        log1=join(LOGDIR, "{sample}_align_warning.log"),
        log2=join(LOGDIR, "{sample}_samblaster.log")
    params:
        index  = INDEX,
        rg     = r"@RG\tID:{sample}\tSM:{sample}\tPL:ILLUMINA",
        outdir = ALIGN_DIR,
        tmpdir = TMPDIR,
        logdir = LOGDIR
    shell:"""
        bwa mem -v 2 -t {threads} -R '{params.rg}' {params.index} {input.r1} {input.r2} 2> {log.log1} |\
            samblaster --addMateTags --splitterFile {output[1]} --discordantFile {output[2]} 2> {log.log2} |\
            sambamba view -S -f bam -l 0 /dev/stdin |\
            sambamba sort -t {threads} -m 2G --tmpdir {params.tmpdir} -o {output[0]} /dev/stdin
        """

# Call variants on all sample BAMs together, one job per chromosome
# (the {chrom} wildcard is passed to --intervals).
# Every BAM in the input becomes its own "-I <bam>" argument.
# Besides the declared VCF and index, the command also writes
# VCF_DIR/{chrom}.bed and VCF_DIR/{chrom}.bam, which are not declared as outputs.
rule SeperateHaplotypeCaller:
    message: "Calling variants with GATK4/HaplotyeCaller"
    input: ALL_BAM
    output: 
        vcf = join(VCF_DIR, "_".join(SAMPLES)) + "_{chrom}.vcf",
        idx = join(VCF_DIR, "_".join(SAMPLES)) + "_{chrom}.vcf.idx"
    params:
        call_conf = 30,
        ploidy = PLOIDY,
        genome = FASTA
    shell:"""
    {GATK} {OPTIONS} HaplotypeCaller \
        -R {params.genome} \
        --intervals {wildcards.chrom} \
        $(for bam in {input}; do echo "-I $bam"; done) \
        --genotyping-mode DISCOVERY \
        -stand-call-conf {params.call_conf} \
        --sample-ploidy {params.ploidy} \
        --assembly-region-out {VCF_DIR}/{wildcards.chrom}.bed \
        --bam-output {VCF_DIR}/{wildcards.chrom}.bam \
        -O {output.vcf}
    """

# Combine the per-chromosome VCFs into one VCF with GATK MergeVcfs.
# Every input file becomes its own "-I <vcf>" argument.
rule MergeVcfs:
    input: ALL_VCF
    output:
        vcf = RAW_VCF,
    shell:"""
    {GATK} {OPTIONS} MergeVcfs \
        $(for vcf in {input}; do echo "-I $vcf"; done) \
        -O {output.vcf}
    """

# SNP filtration thresholds, derived from the average depth set in the config
import math
average_depth = config['depth']
# upper depth bound: ceil(2 * (depth + 4 * sqrt(depth)))
depth_up      = math.ceil(2 * (average_depth + 4.0 * math.sqrt(average_depth)))
# lower depth bound: floor(2 * sqrt(depth)); handed to snp_filter as a param
depth_down    = math.floor(2 * math.sqrt(average_depth))

# Keep SNPs only and flag low-quality ones.
# Step 1: SelectVariants extracts SNPs into temp_snps.vcf (in the working dir).
# Step 2: VariantFiltration applies two named filters, "gatk_snp_filter" and
#         "high_depth_filter" (DP above depth_up AND QUAL below qual_cutoff),
#         plus SNP-cluster flagging (-cluster 2 -window 100).
# Step 3: the temporary VCF and its index are removed.
# params.depth_down is declared but not referenced by the command.
rule snp_filter:
    message: "filter snp with SelectVariants(GATK4) and VariantFiltration(GATK4)"
    input: RAW_VCF
    output:
        vcf   = FLT_VCF,
        index = FLT_VCF + '.idx'
    params:
        genome = FASTA,
        depth_up = depth_up,
        depth_down = depth_down,
        qual_cutoff   = 2.0 * depth_up
    shell:"""
    {GATK} {OPTIONS} SelectVariants -R {params.genome} -V {input} \
        -select-type SNP -O temp_snps.vcf
    {GATK} {OPTIONS} VariantFiltration -R {params.genome} -V temp_snps.vcf \
    -cluster 2 -window 100 \
    --filter-expression "QD < 2.0 || FS > 60.0 || MQ < 40.0 || MQRankSum < -12.5 || ReadPosRankSum < -8.0" \
    --filter-name "gatk_snp_filter" \
    --filter-expression "DP > {params.depth_up} && QUAL < {params.qual_cutoff}" \
    --filter-name "high_depth_filter" \
    -O {output.vcf}
    rm -f temp_snps.vcf temp_snps.vcf.idx
    """

# Annotate the filtered VCF with SnpEff (database params.db, "-o gatk" output),
# compress the result with bcftools view -Oz and index it with bcftools index.
# The SnpEff HTML summary is written to params.outdir/snpeff_summary.html.
rule snpeff:
    message: "annotate with SnpEff"
    input: FLT_VCF
    output:
       vcf   = ANN_VCF,
       index = ANN_VCF + '.csi'
    params:
        outdir = "report/snpeff",
        db = SnpEffDB
    shell:"""
        mkdir -p {params.outdir}
        java -Xmx4G -jar {snpeff} ann -o gatk -s {params.outdir}/snpeff_summary.html {params.db} {input} | \
        bcftools view  -Oz -o {output.vcf}
        bcftools index {output.vcf}
        """

# Write `bcftools stats` output for any VCF_DIR/{sample}.vcf.gz to
# VCF_DIR/{sample}.stat.txt.
rule vcf_stats:
    input: join(VCF_DIR, "{sample}.vcf.gz")
    output: join(VCF_DIR, "{sample}.stat.txt")
    shell:
        "bcftools stats {input} > {output}"

# Flatten the annotated VCF into a CSV table.
# Records matching GT~"\." or GT="AA" are excluded, only FILTER="PASS"
# records are kept, and each row holds CHROM, POS, then GT and the first two
# AD values for every sample, then the EFF annotation.
# The two sed calls tidy the header line written by `bcftools query -H`:
# they drop the leading "# " and the "[N]" column-index prefixes.
# The index pattern accepts any number of digits, so it also works with
# ten or more columns.
# The command is a raw string so that "\.", "\[", "\]" and "\n" reach the
# shell unchanged instead of being treated as Python escape sequences.
rule vcf2csv:
    input:
        join(VCF_DIR,"_".join(SAMPLES) + '_ann' + '.vcf.gz')
    output:
        join("report","all_gene.csv")
    message: "formatting the vcf with bcftools query"
    shell: r"""
    bcftools view -e 'GT~"\."' {input} | bcftools view -e 'GT="AA"' |\
    bcftools query -i 'FILTER="PASS"' -H -f "%CHROM,%POS[,%GT,%AD{{0}},%AD{{1}}],%EFF\n" > {output}
    sed -i '1s/# //' {output}
    sed -i '1s/\[[0-9]*\]//g' {output}
    """

# Run script/snp_filter.R on the gene table to produce the candidate-gene CSV
# and the SNP trend plot. The sample names and the configured threshold are
# handed to the script as params.
rule result:
    input: ALL_CSV
    output:
        csv = CAN_CSV,
        pdf = join("report", "snp_trend_plot.pdf")
    params:
        wildtype = SAMPLES[0],
        mutation = SAMPLES[1],
        threshold= config['threshold']
    script:
        "script/snp_filter.R"
